Python is an object-oriented scripting language and does not require a specific first or last line (such as public static void main in Java or a return statement in C).
There are no curly braces {} to define code blocks and no semicolons ; to end lines. Instead of braces, indentation is strictly enforced to delimit a block of code.
In [9]:
# This is a comment
if (3 < 2):
    print "True" # Another comment. This print syntax only works in Python 2, not 3
else:
    print "False"
Arbitrary indentation can be used within a code block, as long as the indentation is consistent.
In [10]:
if (1 == 1):
    print "We're in "
    print "Deep Trouble:"
In [11]:
if (0 > -1):
 print "This works "
 print "just fine."
Variables can be given alphanumeric names beginning with an underscore or letter. Variable types do not have to be declared and are inferred at run time.
In [12]:
a = 1
print type(a) # built-in function
In [13]:
b = 2.5
print type(b)
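Because types are inferred at run time, the same name can later be rebound to a value of a different type. A quick sketch with a throwaway variable:
x = 42 # x starts as an int
print type(x)
x = "forty-two" # rebinding the same name to a str is allowed
print type(x)
_score = 3 # names may also begin with an underscore
print _score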
Strings can be declared with either single or double quotes.
In [14]:
c1 = "Go "
c2 = 'Gators'
c3 = c1 + c2
print c3
print type(c3)
Variable scope is local to the enclosing function, class, or file, in that increasing order of scope. Variables can also be declared global.
In [15]:
print "b used to be", b # Prints arguments with a space separator
# Our first function declaration
def sum():
global b
b = a + b
sum() # calling sum
# using this syntax, the arguments can be of any type that supports a string representation. No casting needed.
print "Now b is", b
In [16]:
# To use math functions, we must import the math module
import math
print cos(0) # NameError: cos by itself is not defined
Whoops. Importing the math module gives us access to all of its functions, but we must call them through the module name, like this:
In [17]:
print math.cos(0)
Alternatively, you can use the from keyword:
In [18]:
from math import cos
print cos(math.pi) # we only imported cos, not the pi constant
Using the from statement we can also import everything from the math module.
Disclaimer: many Pythonistas discourage doing this because it clutters the namespace. Just import what you need.
In [19]:
from math import *
print sin(pi/2) # now we don't have to make a call to math
In [20]:
mystring = "Go Gators, Come on Gators, Get up and go!"
print mystring[11:25]
Python is a zero-indexed language. Generally, whenever forming a range of values in Python, the first argument is inclusive and the second is not, i.e. mystring[11:25] returns characters 11 through 24.
You can omit the first or second argument:
In [21]:
print mystring[:9] # all characters before the 9th index
In [22]:
print mystring[27:] # all characters at or after the 27th
In [23]:
print mystring[:] # you can even omit both arguments
Using negative values, you can count positions backwards
In [24]:
print mystring[-3:-1]
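Slices also accept an optional third argument, a step. A quick sketch:
print mystring[0:9:2] # every other character of "Go Gators"
print mystring[::-1] # a step of -1 walks backwards, reversing the string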
In [25]:
print mystring.find("Gators") # returns the index of the first occurence of Gators
In [26]:
print mystring.find("Gators", 4) # specify an index on which to begin searching
In [27]:
print mystring.find("Gators", 4, 19) # specify both begin and end indexes to search
Looks like nothing was found in that range; find returns -1 when there is no match.
In [28]:
print mystring.find("Seminoles") # no Seminoles here
In [29]:
print mystring.lower()
print mystring.upper()
In [30]:
print mystring.replace("Gators", "Seminoles") # replaces all occurences of Gators with Seminoles
In [31]:
print mystring
Notice that replace returned a new string; nothing was modified in place.
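That is because Python strings are immutable. A minimal sketch of what happens if you try to change one in place:
try:
    mystring[0] = 'g' # strings do not support item assignment
except TypeError as e:
    print "TypeError:", e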
In [32]:
print mystring.replace("Gators", "Seminoles", 1) # limit the number of replacements
In [33]:
print mystring.split() # returns a list of strings, split on whitespace by default
In [34]:
print mystring.split(',') # you can also define the separator
In [35]:
print ' '.join(["Go", "Gators"])
For more information on string functions:
https://docs.python.org/2/library/stdtypes.html#string-methods
In [36]:
mylist = [1, 2, 3, 4, 'five']
print mylist
In [37]:
mylist.append(6.0) # add an item to the end of the list
print mylist
In [38]:
mylist.extend([8, 'nine']) # extend the list with the contents of another list
print mylist
In [39]:
mylist.insert(6, 7) # insert the number 7 at index 6
print mylist
In [40]:
mylist.remove('five') # removes the first matching occurrence
print mylist
In [41]:
popped = mylist.pop() # by default, the last item in the list is removed and returned
print popped
print mylist
In [42]:
popped2 = mylist.pop(4) # pops the item at index 4
print popped2
print mylist
In [43]:
print len(mylist) # returns the length of any iterable such as lists and strings
In [44]:
mylist.extend(range(-3, 0)) # in Python 2, range returns a list from -3 inclusive to 0 exclusive
print mylist
In [45]:
# default list sorting. When more complex objects are in the list, arguments can be used to customize how to sort
mylist.sort()
print mylist
In [46]:
mylist.reverse() # reverse the list
print mylist
For more information on Lists:
https://docs.python.org/2/tutorial/datastructures.html#more-on-lists
Python supports n-tuple sequences. These are immutable:
In [47]:
mytuple = 'Tim', 'Tebow', 15 # Created with commas
print mytuple
print type(mytuple)
In [48]:
print mytuple[1] # access an item
In [49]:
mytuple[1] = "Winston" # results in error
In [50]:
schools = ['Florida', 'Florida State', 'Miami', 'Florida']
myset = set(schools) # the set is built from the schools list
print myset
In [51]:
print 'Georgia' in myset # membership test
In [52]:
print 'Florida' in myset
In [53]:
badschools = set(['Florida State', 'Miami'])
print myset - badschools # set difference
In [54]:
print myset & badschools # intersection (AND)
In [55]:
print myset | set(['Miami', 'Stetson']) # union (OR)
In [56]:
print myset ^ set(['Miami', 'Stetson']) # symmetric difference (XOR)
In [57]:
mydict = {'Florida' : 1, 'Georgia' : 2, 'Tennessee' : 3}
print mydict
In [58]:
print mydict['Florida'] # access the value with key = 'Florida'
In [59]:
del mydict['Tennessee'] # the del statement removes a key-value pair
print mydict
In [60]:
mydict['Georgia'] = 7 # assignment
print mydict
In [61]:
mydict['Kentucky'] = 6 # assigning to a new key adds it to the dictionary
print mydict
In [62]:
print mydict.keys() # get a list of keys
In [63]:
a = 2; b = 1;
if a > b: print "a is greater than b"
In [64]:
if b > a:
    print "b is greater than a"
else:
    print "b is less than or equal to a"
In [65]:
b = 2
if a > b:
    print "a is greater than b"
elif a < b:
    print "a is less than b"
else:
    print "a is equal to b"
In [66]:
for x in range(10): # with one argument, range produces integers from 0 to 9
    print x
In [67]:
for y in range(5, 12): # with two arguments, range produces integers from 5 to 11
    print y
In [68]:
for z in range(1, 12, 3): # with three arguments, range starts at 1 and steps by 3 while the value is less than 12
    print z
In [69]:
for a in range(10, 1, -5): # can use a negative step size as well
    print a
In [70]:
for b in range(2, 1, 1): # with a positive step, the start must be less than the stop; no integers are produced
    print b
In [71]:
for c in range(1, 2, -1): # likewise with a negative step, the start must be greater than the stop
    print c
In [72]:
for i in ['foo', 'bar']: # iterate over a list of strings
    print i
In [73]:
anotherdict = {'one' : 1, 'two' : 2, 'three' : 3}
for key in anotherdict.keys(): # iterate over a dictionary. Order is not guaranteed
    print key, anotherdict[key]
In [74]:
a = 1; b = 4; c = 7; d = 5;
while (a < b) and (c > d): # example of an "and" condition
    print c - a
    a += 1 # example of incrementing
    c -= 1 # decrementing
Python does not have a construct for a do-while loop, though the same effect can be achieved using the break statement:
In [75]:
a = 1; b = 10
while True: # short circuit the while condition
    a *= 2
    print a
    if a > b:
        break
Functions in Python do not distinguish between those that do and do not return a value, and a returned value's type is not declared.
Functions can be declared in any module, with no distinction between static and non-static. Functions can even be declared within other functions.
The syntax is the following:
In [76]:
def hello():
    print "Hello there!"
hello()
In [77]:
def player(name, number): # use some arguments
    print "#" + str(number), name # cast number to a string when concatenating
player("Kasey Hill", 0)
Functions can have optional arguments if a default value is provided in the function signature
In [78]:
def player(name, number, team = 'Florida'): # optional team argument
    print "#" + str(number), name, team
player("Kasey Hill", 0) # no team argument supplied
In [79]:
player("Aaron Harrison", 2, "Kentucky") # supplying all three arguments
Python functions can be called using named arguments, instead of positional
In [80]:
player(number = 23, name = 'Chris Walker')
In [81]:
args = ['Michael Frazier II', 20, 'Florida']
player(*args) # calling player with the unpacked argument list
Argument lists can also be used when defining a function:
In [82]:
def foo(*args):
    for someFoo in args:
        print someFoo
foo('la', 'dee', 'da') # supports an arbitrary number of arguments
In [83]:
kwargs = {'name' : 'Michael Frazier II', 'number' : 20}
player(**kwargs) # calling player with the unpacked kwargs dictionary. The team argument will take its default
Just as before, we can define a function taking an arbitrary dictionary
In [84]:
def foo(**kwargs):
    for key in kwargs.keys():
        print key, kwargs[key]
foo(**kwargs)
In [85]:
def sum(x,y):
    return x + y # return a single value
print sum(1,2)
In [86]:
def sum_and_product(x,y):
    return x + y, x * y # return two values
mysum, myproduct = sum_and_product(1,2)
print mysum, myproduct
Now that we've covered some Python basics, we will begin a tutorial going through many tasks a data scientist may perform. We will obtain real-world data and go through the process of auditing, analyzing, visualizing, and building classifiers from the data.
We will use a database of breast cancer data obtained from the University of Wisconsin Hospitals, Madison from Dr. William H. Wolberg. The data is a collection of samples from Dr. Wolberg's clinical cases with attributes pertaining to tumors and a class labeling the sample as benign or malignant.
| Attribute | Domain |
|---|---|
| 1. Sample code number | id number |
| 2. Clump Thickness | 1 - 10 |
| 3. Uniformity of Cell Size | 1 - 10 |
| 4. Uniformity of Cell Shape | 1 - 10 |
| 5. Marginal Adhesion | 1 - 10 |
| 6. Single Epithelial Cell Size | 1 - 10 |
| 7. Bare Nuclei | 1 - 10 |
| 8. Bland Chromatin | 1 - 10 |
| 9. Normal Nucleoli | 1 - 10 |
| 10. Mitoses | 1 - 10 |
| 11. Class | 2 for benign, 4 for malignant |
For more information on this data set: https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Diagnostic%29
Let's begin by programmatically obtaining the data. Here I'll define a function we can use to make HTTP requests and download the data:
In [87]:
def download_file(url, local_filename):
    import requests
    # stream = True allows downloading of large files; prevents loading the entire file into memory
    r = requests.get(url, stream = True)
    with open(local_filename, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)
                f.flush()
Now we'll specify the url of the file and the file name we will save to
In [88]:
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data'
filename = 'breast-cancer-wisconsin.csv'
And make a call to download_file
In [89]:
download_file(url, filename)
Now this might seem like overkill for downloading a single, small csv file, but we can use this same function to access countless APIs available on the World Wide Web by building an API request in the url.
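For example, a request can often be built by appending query parameters to the url. A minimal sketch; the endpoint and parameters here are hypothetical, purely for illustration:
base_url = 'https://api.example.com/v1/datasets' # hypothetical API endpoint
query = '?topic=breast-cancer&format=csv' # hypothetical query parameters
download_file(base_url + query, 'api-results.csv')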
Now that we have some data, let's get it into a useful form. For this task we will use a package called pandas. pandas is an open source, BSD-licensed library providing high-performance, easy-to-use data structures and data analysis tools for Python. The most fundamental data structure in pandas is the dataframe, which is similar to the data.frame data structure found in the R statistical programming language.
For more information: http://pandas.pydata.org
pandas dataframes are a 2-dimensional labeled data structures with columns of potentially different types. Dataframes can be thought of as similar to a spreadsheet or SQL table.
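As a quick sketch (with made-up toy data), a small dataframe can be built directly from a dictionary of columns:
import pandas as pd
toy = pd.DataFrame({'school' : ['Florida', 'Kentucky'], 'ranking' : [1, 2]})
print toy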
There are numerous ways to build a dataframe with pandas. Since we have already obtained a csv file, we can use a parser built into pandas called read_csv, which reads the contents of a csv file directly into a dataframe.
For more information: http://pandas.pydata.org/pandas-docs/dev/generated/pandas.io.parsers.read_csv.html
In [90]:
import pandas as pd # import the module and alias it as pd
cancer_data = pd.read_csv('breast-cancer-wisconsin.csv')
cancer_data.head() # show the first few rows of the data
Out[90]:
Whoops, looks like our csv file did not contain a header row. read_csv assumes the first row of the csv is the header by default.
Let's check out the file located here: https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.names
This contains information about the data set including the names of the attributes.
Let's create a list of these attribute names to use when reading the csv file:
In [91]:
# \ allows multi-line wrapping
cancer_header = [ \
    'sample_code_number', \
    'clump_thickness', \
    'uniformity_cell_size', \
    'uniformity_cell_shape', \
    'marginal_adhesion', \
    'single_epithelial_cell_size', \
    'bare_nuclei', \
    'bland_chromatin', \
    'normal_nucleoli', \
    'mitoses', \
    'class']
Let's try the import again, this time specifying the names. When specifying names, the read_csv function requires us to set the header row number to None:
In [92]:
cancer_data = pd.read_csv('breast-cancer-wisconsin.csv', header=None, names=cancer_header)
cancer_data.head()
Out[92]:
Let's take a look at some simple statistics for the clump_thickness column:
In [93]:
cancer_data["clump_thickness"].describe()
Out[93]:
Referring to the documentation link above about the data, the count, range of values (min = 1, max = 10), and data type (dtype = float64) look correct.
Let's take a look at another column, this time bare_nuclei:
In [94]:
cancer_data["bare_nuclei"].describe()
Out[94]:
Well, at least the count is correct. We were expecting no more than 10 unique values, yet the data type is an object.
What's up with our data?
We have arrived at arguably the most important part of performing data science: dealing with messy data. One of most important tools in a data scientist's toolbox is the ability to audit, clean, and reshape data. The real world is full of messy data and your sources may not always have data in the exact format you desire.
In this case we are working with csv data, which is a relatively straightforward format, but this will not always be the case when performing real world data science. Data comes in all varieties from csv all the way to something as unstructured as a collection of emails or documents. A data scientist must be versed in a wide variety of technologies and methodologies in order to be successful.
Now, let's do a little bit of digging into why we are not getting a numeric pandas column:
In [95]:
cancer_data["bare_nuclei"].unique()
Out[95]:
Using unique we can see that '?' is one of the distinct values that appears in this series. Looking again at the documentation for this data set, we find the following:
Missing attribute values: 16
There are 16 instances in Groups 1 to 6 that contain a single missing (i.e., unavailable) attribute value, now denoted by "?".
It was so nice of them to tell us to expect these missing values, but as a data scientist that will almost never be the case. Let's see what we can do with these missing values.
In [96]:
cancer_data["bare_nuclei"] = cancer_data["bare_nuclei"].convert_objects(convert_numeric=True)
Here we have attempted to convert the bare_nuclei series to a numeric type. Let's see what the unique values are now.
In [97]:
cancer_data["bare_nuclei"].unique()
Out[97]:
The decimal point after each number means that these are integer values being represented by floating point numbers. Now, instead of our pesky '?', we have nan (not a number). nan is a construct used by pandas to represent the absence of a value. It is a data type that comes from the package numpy, used internally by pandas, and is not part of the standard Python library.
Now that we have nan values in place of '?', we can use some nice features in pandas to deal with these missing values.
What we are about to do is what is called "imputing" or providing a replacement for missing values so the data set becomes easier to work with. There are a number of strategies for imputing missing values, all with their own pitfalls. In general, imputation introduces some degree of bias to the data, so the imputation strategy taken should be in an attempt to minimize that bias.
Here, we will simply use the mean of all of the non-nan values in the series as a replacement. Since we already know the possible values are integers, we will round the mean to the nearest whole number.
In [98]:
cancer_data.fillna(cancer_data.mean().round(), inplace=True)
cancer_data["bare_nuclei"].unique()
Out[98]:
fillna is a dataframe function that replaces all nan values with either a scalar value, a series of values with the same indices as found in the dataframe, or a dataframe that is indexed by the columns of the target dataframe.
cancer_data.mean().round() takes the mean of each column (this computation ignores the currently present nan values), rounds it, and returns a dataframe indexed by the columns of the original dataframe:
In [99]:
cancer_data.mean().round()
Out[99]:
inplace=True allows us to make this modification directly on the dataframe, without having to do any assignment.
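Mean imputation is only one strategy. As a hedged sketch of alternatives, working on a fresh copy of the raw data so our dataframe is untouched, we could impute with the per-column median, or simply discard the incomplete rows:
raw = pd.read_csv('breast-cancer-wisconsin.csv', header=None, names=cancer_header)
raw = raw.convert_objects(convert_numeric=True)
median_filled = raw.fillna(raw.median().round()) # the median is more robust to outliers than the mean
dropped = raw.dropna() # or drop the 16 rows containing a nan value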
Now that we have figured out how to impute these missing values in a single column, let's start over and quickly apply this technique to the entire dataframe.
In [100]:
cancer_data = pd.read_csv('breast-cancer-wisconsin.csv', header=None, names=cancer_header)
cancer_data = cancer_data.convert_objects(convert_numeric=True)
cancer_data.fillna(cancer_data.mean().round(), inplace=True)
cancer_data["bare_nuclei"].describe()
Out[100]:
In [101]:
cancer_data["bare_nuclei"].unique()
Out[101]:
Structurally, a pandas dataframe is a collection of Series objects sharing a common index. In general, the Series object and the Dataframe object share a large number of functions, with some behavioral differences. In other words, whatever computation you can do on a single column can generally be applied to the entire dataframe.
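For instance (a quick sketch), the same call works on a single column or on the whole dataframe:
print cancer_data["clump_thickness"].max() # maximum of one Series
print cancer_data.max() # column-wise maxima for the entire dataframe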
Now we can use the dataframe version of describe to get an overview of all of our data:
In [102]:
cancer_data.describe()
Out[102]:
Another important tool in the data scientist's toolbox is the ability to create visualizations from data. Visualizing data is often the most logical place to start getting a deeper intuition of the data. This intuition will shape and drive your analysis.
Even more important than visualizing data for your own personal benefit, it is often the job of the data scientist to use the data to tell a story. Creating illustrative visuals that succinctly convey an idea is the best way to tell that story, especially to stakeholders with less technical skillsets.
Here we will be using a Python package called ggplot (https://ggplot.yhathq.com). The ggplot package is an attempt to bring visuals following the guidelines outlined in the grammar of graphics (http://vita.had.co.nz/papers/layered-grammar.html) to Python. It is based on, and intended to mimic the features of, the ggplot2 library found in R. Additionally, ggplot is designed to work with pandas dataframes, making things nice and simple.
We'll start by doing a bit of setup
In [103]:
# The following line is NOT Python code, but a special syntax for enabling inline plotting in IPython
%matplotlib inline
from ggplot import *
import warnings
# ggplot usage of pandas throws a future warning
warnings.filterwarnings('ignore')
So we enabled plotting in IPython and imported everything from the ggplot package. Now we'll create a plot and then break down the components
In [104]:
plt = ggplot(aes(x = 'clump_thickness'), data = cancer_data) + \
    geom_histogram(binwidth = 1, fill = 'steelblue')
# using print gets the plot to show up here within the notebook.
# In a normal Python environment without using print, the plot appears in a window
print plt
A plot begins with the ggplot function. Here, we pass in the cancer_data pandas dataframe and a special function called aes (short for aesthetic). The values provided to aes change depending on which type of plot is being used. Here we are going to make a histogram from the clump_thickness column in cancer_data, so that column name needs to be passed as the x parameter to aes.
The grammar of graphics is based on a concept of "geoms" (short for geometric objects). These geoms provide granular control of the plot and are progressively added to the base call to ggplot with + syntax.
Let's say we wanted to show the mean clump_thickness on this plot. We could do something like the following:
In [105]:
plt = ggplot(aes(x = 'clump_thickness'), data = cancer_data) + \
    geom_histogram(binwidth = 1, fill = 'steelblue') + \
    geom_vline(xintercept = [cancer_data['clump_thickness'].mean()], linetype='dashed')
print plt
As you can see, each geom has its own set of parameters specific to the appearance of that geom (also called aesthetics).
Let's try a scatter plot to get some multi-variable action:
In [106]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei'), data = cancer_data) + \
    geom_point()
print plt
Sometimes when working with integer data, or data that takes on a limited range of values, it is easier to visualize the plot with added jitter to the points. We can do that by adding an aesthetic to geom_point.
In [107]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei'), data = cancer_data) + \
    geom_point(position = 'jitter')
print plt
With a simple aesthetic addition, we can see how these two variables play into our cancer classification
In [108]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei', color = 'class'), data = cancer_data) + \
    geom_point(position = 'jitter')
print plt
By adding color = 'class' as a parameter to the aes function, we now give a color to each unique value found in that column and automatically get a legend. Remember, 2 is benign and 4 is malignant.
We can also do things such as add a title or change the axis labeling with geoms
In [109]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei', color = 'class'), data = cancer_data) + \
    geom_point(position = 'jitter') + \
    ggtitle("The Effect of the Bare Nuclei and Cell Shape Uniformity on Classification") + \
    ylab("Amount of Bare Nuclei") + \
    xlab("Uniformity in Cell shape")
print plt
There is definitely some patterning going on in that plot.
A slightly different way to convey this idea is to use faceting. Faceting is the creation of multiple related plots arranged by the values of a given faceted variable
In [110]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei'), data = cancer_data) + \
    geom_point(position = 'jitter') + \
    ggtitle("The Effect of the Bare Nuclei and Cell Shape Uniformity on Classification") + \
    facet_grid('class')
print plt
Rather than set the color equal to the class, we have created two plots based on the class. With a facet, we can get very detailed. Let's throw some more variables into the mix:
In [111]:
plt = ggplot(aes(x = 'uniformity_cell_shape', y = 'bare_nuclei', color = 'class'), data = cancer_data) + \
    geom_point(position = 'jitter') + \
    ggtitle("The Effect of the Bare Nuclei and Cell Shape Uniformity on Classification") + \
    facet_grid('clump_thickness', 'marginal_adhesion')
print plt
Unfortunately, legends for faceting are not yet implemented in the Python ggplot package. In this example we faceted on the x-axis with clump_thickness and along the y-axis with marginal_adhesion, then created 100 plots of uniformity_cell_shape vs. bare_nuclei effect on class.
I highly encourage you to check out https://ggplot.yhathq.com/docs/index.html to see all of the available geoms. The best way to learn is to play with and visualize the data with many different plots and aesthetics.
So now that we've acquired, audited, cleaned, and visualized our data, we have arrived at machine learning. By formal definition from Tom Mitchell:
A computer program is said to learn from experience E with respect to some task T and some performance measure P if its performance on T, as measured by P, improves with experience E.
Okay, that's a bit ridiculous. Essentially, machine learning is the science of building algorithms that learn from data in order to make predictions about the data. There are two main classes of machine learning: supervised and unsupervised.
In supervised learning, an algorithm will use the features of the data given to make a prediction about a known label. For example, we will use supervised learning here to take features such as bare_nuclei and uniformity_cell_shape and predict a tumor class (benign or malignant). This type of machine learning is called supervised because the class labels (benign or malignant) are a known quantity during learning, so we are supervising the algorithm with the "correct" answer.
In unsupervised learning, an algorithm will use the features of the data to discover what types of labels there could be. The "correct" answer is not known.
In this session we will be mostly focused on supervised learning as we attempt to predict whether a tumor is benign or malignant. We will also be focused on doing some practical machine learning, and will gloss over the algorithmic details.
The first thing we have to do is extract the class labels and features from cancer_data and store them as separate arrays. In our first classifier we will only choose two features from cancer_data to keep things simple:
In [ ]:
cancer_features = ['uniformity_cell_shape', 'bare_nuclei']
Here we call values on the dataframe to extract the values stored in the dataframe as an array of numpy arrays with the same dimensions as our subsetted dataframe. numpy is a powerful, high-performance scientific computing package that implements arrays; it is used internally by pandas. We will use labels and features later on in our machine learning classifier:
In [118]:
labels = cancer_data['class'].values
features = cancer_data[cancer_features].values
An important concept in machine learning is to split the data set into training data and testing data. The machine learning algorithm will use the subset of training data to build a classifier to predict labels. We then test the accuracy of this classifier on the subset of testing data. This is done in order to prevent overfitting the classifier to one given set of data.
Overfitting is a major concern in the design of machine learning algorithms. Conceptually, overfitting is when a classifier is really good at predicting the data used to build it, but isn't robust or general enough to predict new, yet unseen data all that well.
To perform machine learning, we will use a package called scikit-learn (sklearn for short). The sklearn cross_validation module contains a function called train_test_split that takes in features and labels, and randomly partitions them into training and testing subsets:
In [113]:
from sklearn.cross_validation import train_test_split
features_train, features_test, labels_train, labels_test = train_test_split(features,
                                                                             labels,
                                                                             test_size = 0.3,
                                                                             random_state = 42)
For this example, we will build a Decision Tree Classifier. The goal of a decision tree is to create a prediction by outlining a simple tree of decision rules. These rules are built from the training data by slicing the data on simple boundaries and trying to minimize the prediction error of that boundary. More details on decision trees can be found here: http://scikit-learn.org/stable/modules/tree.html
The first step is to import the classifier from the sklearn.tree module.
In [ ]:
from sklearn.tree import DecisionTreeClassifier
Next, we create a variable to store the classifier
In [ ]:
clf = DecisionTreeClassifier()
Then we have to fit the classifier to the training data. Both the training features (uniformity_cell_shape and bare_nuclei) and the labels (benign vs. malignant) are passed to the fit function
In [ ]:
clf.fit(features_train, labels_train)
The classifier is now ready to make some predictions. We can use the score function to see how accurate the classifier is on the test data. The score function takes the data in features_test, makes a prediction of benign or malignant based on the decision tree that was fit to the training data, and compares that prediction to the true values in labels_test:
In [114]:
print "Accuracy score:", clf.score(features_test,labels_test)
Nearly all classifiers, decision trees included, will have parameters that can be tuned to build a more accurate model. Without any parameter tuning and using just two features we have made a pretty accurate prediction. Good job!
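As a hedged sketch of what tuning might look like: min_samples_split is one such parameter on sklearn's DecisionTreeClassifier (it requires more samples at a node before splitting); whether it actually improves accuracy here depends on the data.
clf2 = DecisionTreeClassifier(min_samples_split = 50) # a less aggressive tree
clf2.fit(features_train, labels_train)
print "Tuned accuracy score:", clf2.score(features_test, labels_test)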
To get a better idea of what is going on, I have included a helper function to plot our test data along with the decision boundary
In [117]:
from class_vis import prettyPicture # helper function
prettyPicture(clf, features_test, labels_test)
The area in red is where the classifier predicts a malignant tumor, whereas the blue area predicts a benign tumor. The color of the points on the plot represents the true label of the data point. Remember, there is no jitter included in this plot, so a number of data points are plotted on top of one another.
The vertical and horizontal lines represent what is called the decision boundary. For example, our classifier predicts all data points with uniformity_cell_shape greater than around 6.25 to be malignant.
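We can also query the fitted classifier directly with a new, hypothetical sample; the feature values below are made up for illustration:
sample = [[8, 7]] # hypothetical: uniformity_cell_shape = 8, bare_nuclei = 7
print clf.predict(sample) # this point falls on the malignant side of the boundary, so we expect [4]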